// Copyright (C) 2024 Kumo inc.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//


#include <kllm/core/sampling.h>
#include <kllm/core/km_context.h>

#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>


namespace kllm {
    // Fixed-capacity ring buffer with a std::deque-like interface.
    // When the buffer is full, push_back() overwrites the oldest element instead of growing.
    // TODO: deduplicate with llama-impl.h
    template<typename T>
    struct ring_buffer {
        // Allocates storage for `cap` elements up front; the buffer starts empty.
        explicit ring_buffer(size_t cap) : capacity(cap), data(cap) {}

        // Oldest element. Throws std::runtime_error if the buffer is empty.
        T &front() {
            if (sz == 0) {
                throw std::runtime_error("ring buffer is empty");
            }
            return data[first];
        }

        const T &front() const {
            if (sz == 0) {
                throw std::runtime_error("ring buffer is empty");
            }
            return data[first];
        }

        // Newest element. `pos` is the slot the NEXT push_back() writes to, so the newest
        // element sits one slot before it (wrapping around). sz > 0 implies capacity > 0,
        // so the modulo is safe. Throws std::runtime_error if the buffer is empty.
        T &back() {
            if (sz == 0) {
                throw std::runtime_error("ring buffer is empty");
            }
            return data[(pos + capacity - 1) % capacity];
        }

        const T &back() const {
            if (sz == 0) {
                throw std::runtime_error("ring buffer is empty");
            }
            return data[(pos + capacity - 1) % capacity];
        }

        // Appends `value`; if the buffer is full the oldest element is dropped.
        // Throws std::runtime_error for a zero-capacity buffer: there is no slot to write
        // into, and the index arithmetic below would otherwise divide by zero.
        void push_back(const T &value) {
            if (capacity == 0) {
                throw std::runtime_error("ring buffer: capacity is zero");
            }
            if (sz == capacity) {
                // advance the start when buffer is full
                first = (first + 1) % capacity;
            } else {
                sz++;
            }
            data[pos] = value;
            pos = (pos + 1) % capacity;
        }

        // Removes and returns the oldest element. Throws std::runtime_error if empty.
        T pop_front() {
            if (sz == 0) {
                throw std::runtime_error("ring buffer is empty");
            }
            T value = data[first];
            first = (first + 1) % capacity;
            sz--;
            return value;
        }

        // Reverse indexing: rat(0) is the newest element, rat(size() - 1) the oldest.
        // Throws std::runtime_error if i >= size().
        const T &rat(size_t i) const {
            if (i >= sz) {
                throw std::runtime_error("ring buffer: index out of bounds");
            }
            return data[(first + sz - i - 1) % capacity];
        }

        // Copies the stored elements out, oldest first.
        std::vector<T> to_vector() const {
            std::vector<T> result;
            result.reserve(sz);
            for (size_t i = 0; i < sz; i++) {
                result.push_back(data[(first + i) % capacity]);
            }
            return result;
        }

        // Logically empties the buffer; the stored elements are neither destroyed nor reset.
        void clear() {
            sz = 0;
            first = 0;
            pos = 0;
        }

        bool empty() const {
            return sz == 0;
        }

        size_t size() const {
            return sz;
        }

        size_t capacity = 0; // maximum number of elements
        size_t sz = 0;       // current number of elements
        size_t first = 0;    // index of the oldest element
        size_t pos = 0;      // index the next push_back() writes to
        std::vector<T> data; // backing storage, always `capacity` slots
    };

    // Sampling state: the configuration, a grammar sampler, the sampler chain, the history of
    // accepted tokens, and a scratch candidate array rebuilt from the logits by set_logits().
    struct common_sampler {
        InternalSamplerParams params;

        // Both samplers are owned by this object and released by common_sampler_free().
        struct llama_sampler *grmr;
        struct llama_sampler *chain;

        // Tokens passed to common_sampler_accept(), newest last.
        ring_buffer<llama_token> prev;

        // Backing storage for cur_p (one entry per vocabulary token after set_logits()).
        std::vector<llama_token_data> cur;

        llama_token_data_array cur_p;

        // Rebuilds `cur` / `cur_p` from the logits of output `idx` of `ctx`.
        // Throws std::runtime_error if no logits are available for `idx`.
        void set_logits(struct llama_context *ctx, int idx) {
            const auto *logits = llama_get_logits_ith(ctx, idx);
            // llama_get_logits_ith() is documented to return NULL for invalid ids; fail loudly
            // instead of dereferencing a null pointer in the loop below.
            if (logits == nullptr) {
                throw std::runtime_error("common_sampler: no logits available for the requested output index");
            }

            const int n_vocab = llama_n_vocab(llama_get_model(ctx));

            cur.resize(n_vocab);

            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
            }

            // nothing selected yet, not sorted
            cur_p = {cur.data(), cur.size(), -1, false};
        }
    };

    struct common_sampler *
    common_sampler_init(const struct llama_model *model, const struct InternalSamplerParams &params) {
        llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

        lparams.no_perf = params.no_perf;

        auto *result = new common_sampler{
                /* .params = */ params,
                /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
                /* .chain  = */ llama_sampler_chain_init(lparams),
                /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
                /* .cur    = */ {},
                /* .cur_p  = */ {},
        };

        llama_sampler_chain_add(result->chain,
                                llama_sampler_init_logit_bias(
                                        llama_n_vocab(model),
                                        params.logit_bias.size(),
                                        params.logit_bias.data()));

        llama_sampler_chain_add(result->chain,
                                llama_sampler_init_penalties(
                                        llama_n_vocab(model),
                                        llama_token_eos(model),
                                        llama_token_nl(model),
                                        params.penalty_last_n,
                                        params.penalty_repeat,
                                        params.penalty_freq,
                                        params.penalty_present,
                                        params.penalize_nl,
                                        params.ignore_eos));

        if (params.mirostat == 0) {
            for (const auto &cnstr: params.samplers) {
                switch (cnstr) {
                    case COMMON_SAMPLER_TYPE_DRY: {
                        std::vector<const char *> c_breakers;
                        c_breakers.reserve(params.dry_sequence_breakers.size());
                        for (const auto &str: params.dry_sequence_breakers) {
                            c_breakers.push_back(str.c_str());
                        }

                        llama_sampler_chain_add(result->chain,
                                                llama_sampler_init_dry(model, params.dry_multiplier, params.dry_base,
                                                                       params.dry_allowed_length,
                                                                       params.dry_penalty_last_n, c_breakers.data(),
                                                                       c_breakers.size()));
                    }
                        break;
                    case COMMON_SAMPLER_TYPE_TOP_K:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.top_k));
                        break;
                    case COMMON_SAMPLER_TYPE_TOP_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p(params.top_p, params.min_keep));
                        break;
                    case COMMON_SAMPLER_TYPE_MIN_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p(params.min_p, params.min_keep));
                        break;
                    case COMMON_SAMPLER_TYPE_XTC:
                        llama_sampler_chain_add(result->chain,
                                                llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold,
                                                                       params.min_keep, params.seed));
                        break;
                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
                        llama_sampler_chain_add(result->chain,
                                                llama_sampler_init_typical(params.typ_p, params.min_keep));
                        break;
                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
                        llama_sampler_chain_add(result->chain,
                                                llama_sampler_init_temp_ext(params.temp, params.dynatemp_range,
                                                                            params.dynatemp_exponent));
                        break;
                    case COMMON_SAMPLER_TYPE_INFILL:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill(model));
                        break;
                    default:
                        GGML_ASSERT(false && "unknown sampler type");
                }
            }
            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
        } else if (params.mirostat == 1) {
            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
            llama_sampler_chain_add(result->chain,
                                    llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau,
                                                                params.mirostat_eta, 100));
        } else if (params.mirostat == 2) {
            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau,
                                                                                  params.mirostat_eta));
        } else {
            GGML_ASSERT(false && "unknown mirostat version");
        }

        return result;
    }

    // Releases both owned llama samplers and the common_sampler itself. A null pointer is a no-op.
    void common_sampler_free(struct common_sampler *gsmpl) {
        if (gsmpl == nullptr) {
            return;
        }

        llama_sampler_free(gsmpl->grmr);
        llama_sampler_free(gsmpl->chain);
        delete gsmpl;
    }

    // Feeds `token` to the grammar sampler (only when accept_grammar is set) and to the sampler
    // chain, then records it in the token history.
    void common_sampler_accept(struct common_sampler *gsmpl, llama_token token, bool accept_grammar) {
        common_sampler &state = *gsmpl;

        if (accept_grammar) {
            llama_sampler_accept(state.grmr, token);
        }
        llama_sampler_accept(state.chain, token);

        state.prev.push_back(token);
    }

    void common_sampler_reset(struct common_sampler *gsmpl) {
        llama_sampler_reset(gsmpl->grmr);

        llama_sampler_reset(gsmpl->chain);
    }

    // Returns an independent copy of `gsmpl`: both llama samplers are cloned, and the parameters,
    // token history and candidate buffer are copied. The caller owns the result and releases it
    // with common_sampler_free().
    struct common_sampler *common_sampler_clone(common_sampler *gsmpl) {
        auto *result = new common_sampler{
                /* .params = */ gsmpl->params,
                /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
                /* .chain  = */ llama_sampler_clone(gsmpl->chain),
                /* .prev   = */ gsmpl->prev,
                /* .cur    = */ gsmpl->cur,
                /* .cur_p  = */ gsmpl->cur_p,
        };

        // cur_p was copied verbatim, so its `data` still points into gsmpl->cur. Re-point it at the
        // clone's own copy (size/selected/sorted stay as copied); otherwise the clone would read or
        // write the source's buffer, which a later set_logits() may reallocate and
        // common_sampler_free() releases.
        if (gsmpl->cur_p.data == gsmpl->cur.data()) {
            result->cur_p.data = result->cur.data();
        }

        return result;
    }

    // Prints perf counters for whichever arguments are non-null: the sampler chain first, then
    // the context.
    void common_perf_print(const struct llama_context *ctx, const struct common_sampler *gsmpl) {
        // TODO: measure grammar performance
        const bool have_sampler = gsmpl != nullptr;
        const bool have_context = ctx != nullptr;

        if (have_sampler) {
            llama_perf_sampler_print(gsmpl->chain);
        }
        if (have_context) {
            llama_perf_context_print(ctx);
        }
    }

    // Samples one token from the logits of output `idx`.
    //  - grammar_first: the grammar filters the candidates, then the chain selects.
    //  - otherwise: the chain selects first and only the chosen token is checked against the
    //    grammar; if the grammar rejects it, the logits are reloaded and sampling is redone with
    //    the grammar applied first.
    // The token is not accepted here (see common_sampler_accept). Aborts via GGML_ASSERT if the
    // chain selects nothing.
    llama_token
    common_sampler_sample(struct common_sampler *gsmpl, struct llama_context *ctx, int idx, bool grammar_first) {
        struct llama_sampler *const grammar = gsmpl->grmr;
        struct llama_sampler *const sampler_chain = gsmpl->chain;
        llama_token_data_array &cur_p = gsmpl->cur_p;

        gsmpl->set_logits(ctx, idx); // fills cur_p

        if (grammar_first) {
            llama_sampler_apply(grammar, &cur_p);
            llama_sampler_apply(sampler_chain, &cur_p);

            GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

            return cur_p.data[cur_p.selected].id;
        }

        // optimistic pass: sample without the grammar
        llama_sampler_apply(sampler_chain, &cur_p);

        GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

        const llama_token id = cur_p.data[cur_p.selected].id;

        // run the grammar over a one-element array holding just the sampled token;
        // a logit of -INFINITY afterwards means the grammar rejected it
        llama_token_data probe = {id, 1.0f, 0.0f};
        llama_token_data_array probe_array = {&probe, 1, -1, false};
        llama_sampler_apply(grammar, &probe_array);

        if (probe_array.data[0].logit != -INFINITY) {
            return id;
        }

        // rejected: rebuild the candidates and sample again with the grammar applied first
        gsmpl->set_logits(ctx, idx);
        llama_sampler_apply(grammar, &cur_p);
        llama_sampler_apply(sampler_chain, &cur_p);

        GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

        return cur_p.data[cur_p.selected].id;
    }

    // Seed reported by the sampler chain.
    uint32_t common_sampler_get_seed(const struct common_sampler *gsmpl) {
        const struct llama_sampler *sampler_chain = gsmpl->chain;
        return llama_sampler_get_seed(sampler_chain);
    }

// helpers

    // Mutable access to the candidate array that set_logits() last built (and the samplers have
    // since filtered). The array is owned by `gsmpl`.
    llama_token_data_array *common_sampler_get_candidates(struct common_sampler *gsmpl) {
        llama_token_data_array &candidates = gsmpl->cur_p;
        return &candidates;
    }

    // Newest token in the history. Throws std::runtime_error (from ring_buffer::rat) when the
    // history is empty.
    llama_token common_sampler_last(const struct common_sampler *gsmpl) {
        const auto &history = gsmpl->prev;
        return history.rat(0);
    }

    // Human-readable description of the sampler chain, e.g. "logits -> a -> b ", listing the
    // chain's samplers in chain order.
    std::string common_sampler_print(const struct common_sampler *gsmpl) {
        std::string result = "logits ";

        const int n_samplers = llama_sampler_chain_n(gsmpl->chain);
        for (int k = 0; k < n_samplers; ++k) {
            const auto *smpl = llama_sampler_chain_get(gsmpl->chain, k);
            result += "-> ";
            result += llama_sampler_name(smpl);
            result += ' ';
        }

        return result;
    }

    // Detokenizes the last `n` tokens of the history (fewer if the history is shorter), oldest
    // first. Returns "" when there is nothing to print.
    std::string common_sampler_prev_str(common_sampler *gsmpl, const KMContext *ctx_main, int n) {
        const int available = (int) gsmpl->prev.size();
        if (available < n) {
            n = available;
        }
        if (n <= 0) {
            return "";
        }

        std::string result;
        // rough guess of 8 bytes per token piece; TODO: compute this from the vocab
        result.reserve(8 * n);

        // rat() counts back from the newest token, so index n-1-k yields the oldest of the
        // selected tokens first
        for (int k = 0; k < n; ++k) {
            const llama_token id = gsmpl->prev.rat(n - 1 - k);

            GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

            result += ctx_main->token_to_piece(id);
        }

        return result;
    }

    // Single-letter code for a sampler type (the inverse of common_sampler_types_from_chars);
    // '?' for types without a code.
    char common_sampler_type_to_chr(KaiSamplerType cnstr) {
        char code = '?';
        switch (cnstr) {
            case COMMON_SAMPLER_TYPE_DRY:         code = 'd'; break;
            case COMMON_SAMPLER_TYPE_TOP_K:       code = 'k'; break;
            case COMMON_SAMPLER_TYPE_TYPICAL_P:   code = 'y'; break;
            case COMMON_SAMPLER_TYPE_TOP_P:       code = 'p'; break;
            case COMMON_SAMPLER_TYPE_MIN_P:       code = 'm'; break;
            case COMMON_SAMPLER_TYPE_TEMPERATURE: code = 't'; break;
            case COMMON_SAMPLER_TYPE_XTC:         code = 'x'; break;
            case COMMON_SAMPLER_TYPE_INFILL:      code = 'i'; break;
            default:                                          break;
        }
        return code;
    }

    // Canonical name of a sampler type (the names accepted by common_sampler_types_from_names);
    // an empty string for unknown types.
    std::string common_sampler_type_to_str(enum KaiSamplerType cnstr) {
        const char *name = "";
        switch (cnstr) {
            case COMMON_SAMPLER_TYPE_DRY:         name = "dry";         break;
            case COMMON_SAMPLER_TYPE_TOP_K:       name = "top_k";       break;
            case COMMON_SAMPLER_TYPE_TYPICAL_P:   name = "typ_p";       break;
            case COMMON_SAMPLER_TYPE_TOP_P:       name = "top_p";       break;
            case COMMON_SAMPLER_TYPE_MIN_P:       name = "min_p";       break;
            case COMMON_SAMPLER_TYPE_TEMPERATURE: name = "temperature"; break;
            case COMMON_SAMPLER_TYPE_XTC:         name = "xtc";         break;
            case COMMON_SAMPLER_TYPE_INFILL:      name = "infill";      break;
            default:                                                    break;
        }
        return name;
    }

    // Converts sampler names to sampler types, preserving order. Canonical names (as produced by
    // common_sampler_type_to_str) are always accepted; the alternative spellings below are
    // accepted only when allow_alt_names is true. Unrecognized names are skipped silently.
    std::vector<KaiSamplerType>
    common_sampler_types_from_names(const std::vector<std::string> &names, bool allow_alt_names) {
        // The lookup tables are constant, so build them once (thread-safe static init) instead
        // of re-allocating both hash maps on every call.
        static const std::unordered_map<std::string, KaiSamplerType> sampler_canonical_name_map{
                {"dry",         COMMON_SAMPLER_TYPE_DRY},
                {"top_k",       COMMON_SAMPLER_TYPE_TOP_K},
                {"top_p",       COMMON_SAMPLER_TYPE_TOP_P},
                {"typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P},
                {"min_p",       COMMON_SAMPLER_TYPE_MIN_P},
                {"temperature", COMMON_SAMPLER_TYPE_TEMPERATURE},
                {"xtc",         COMMON_SAMPLER_TYPE_XTC},
                {"infill",      COMMON_SAMPLER_TYPE_INFILL},
        };

        // since samplers names are written multiple ways
        // make it ready for both system names and input names
        static const std::unordered_map<std::string, KaiSamplerType> sampler_alt_name_map{
                {"top-k",     COMMON_SAMPLER_TYPE_TOP_K},
                {"top-p",     COMMON_SAMPLER_TYPE_TOP_P},
                {"nucleus",   COMMON_SAMPLER_TYPE_TOP_P},
                {"typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P},
                {"typical",   COMMON_SAMPLER_TYPE_TYPICAL_P},
                {"typ-p",     COMMON_SAMPLER_TYPE_TYPICAL_P},
                {"typ",       COMMON_SAMPLER_TYPE_TYPICAL_P},
                {"min-p",     COMMON_SAMPLER_TYPE_MIN_P},
                {"temp",      COMMON_SAMPLER_TYPE_TEMPERATURE},
        };

        std::vector<KaiSamplerType> samplers;
        samplers.reserve(names.size());

        for (const auto &name: names) {
            const auto canonical = sampler_canonical_name_map.find(name);
            if (canonical != sampler_canonical_name_map.end()) {
                samplers.push_back(canonical->second);
                continue;
            }

            if (!allow_alt_names) {
                continue;
            }

            const auto alt = sampler_alt_name_map.find(name);
            if (alt != sampler_alt_name_map.end()) {
                samplers.push_back(alt->second);
            }
        }

        return samplers;
    }

    // Converts a string of single-letter sampler codes (see common_sampler_type_to_chr) into
    // sampler types, preserving order. Characters without a matching sampler are skipped.
    std::vector<KaiSamplerType> common_sampler_types_from_chars(const std::string &chars) {
        // common_sampler_type_to_chr() is a pure mapping, so the table is constant: build it once
        // (thread-safe static init) instead of re-allocating the hash map on every call.
        static const std::unordered_map<char, KaiSamplerType> sampler_name_map = {
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY},
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K},
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P},
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P},
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P},
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE},
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC},
                {common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL},
        };

        std::vector<KaiSamplerType> samplers;
        samplers.reserve(chars.size());

        for (const char c: chars) {
            const auto sampler = sampler_name_map.find(c);
            if (sampler != sampler_name_map.end()) {
                samplers.push_back(sampler->second);
            }
        }

        return samplers;
    }
}  // namespace kllm
